import jsonlines
import sys
import re


def extract_answer(response, pattern):
    """Pull the first ``pattern`` match out of a model response.

    Steps, in order:

    1. If the chat tag ``<|im_start|>assistant`` occurs in ``response``, keep
       only the text between its first and (if any) second occurrence and
       remove every space character from it.
    2. If the marker ``"\\n\\nAnswer:"`` occurs, keep only the text between
       its first and (if any) second occurrence, with spaces removed.
       Otherwise the text from step 1 is used unchanged.
    3. Run ``re.findall(pattern, text)`` and return its first element.

    Args:
        response: Raw response text.
        pattern: A regex string or compiled pattern accepted by
            ``re.findall``.

    Returns:
        The first element of ``re.findall`` (the whole match, or the group
        content when ``pattern`` defines groups), or the literal string
        ``"None"`` when nothing matches.
    """
    assistant_tag = "<|im_start|>assistant"
    if assistant_tag in response:
        response = response.split(assistant_tag)[1]
        response = response.replace(" ", "")

    # NOTE(review): when neither the chat tag nor the answer marker is present
    # the text is matched with its spaces intact, so e.g. "Hypothesis 1" does
    # not match "Hypothesis[1-2]". Kept as-is to avoid shifting reported
    # numbers -- confirm whether this is intended.
    segments = response.split("\n\nAnswer:")
    if len(segments) > 1:
        prediction = segments[1].replace(" ", "")
    else:
        prediction = response

    # Explicit emptiness check instead of a bare ``except``: real errors
    # (invalid regex, KeyboardInterrupt, ...) are no longer swallowed.
    matches = re.findall(pattern, prediction)
    return matches[0] if matches else "None"


# The model name is the only CLI argument; it selects both input files.
if len(sys.argv) < 2:
    sys.exit(f"usage: python {sys.argv[0]} MODEL")
model = sys.argv[1]

# Read both result files fully, closing each reader as soon as it is consumed.
with jsonlines.open(f"./output/{model}.jsonl", "r") as reader:
    direct = list(reader)
with jsonlines.open(f"./output/{model}-v2.jsonl", "r") as reader:
    judge = list(reader)

pattern_judge = re.compile("Yes|No")
pattern_direct = re.compile("Hypothesis[1-2]{1}")

# general_rule -> list of booleans, True when the extracted judge answer is
# "Yes". Anything else (including the "None" no-match sentinel) is False.
judge_dict = {}
for je in judge:
    is_yes = extract_answer(je["answer"], pattern_judge) == "Yes"
    judge_dict.setdefault(je["general_rule"], []).append(is_yes)


# Extracted answer -> index compared against each record's "label".
# Presumably labels are 0 or 1, so "None" (2) never counts as correct --
# TODO confirm against the dataset.
direct_mapping = {"Hypothesis1": 0, "Hypothesis2": 1, "None": 2}

# general_rule -> list of booleans, True when the mapped prediction equals
# the record's label.
direct_dict = {}
for de in direct:
    prediction = extract_answer(de["answer"], pattern_direct)
    is_correct = direct_mapping[prediction] == de["label"]
    direct_dict.setdefault(de["general_rule"], []).append(is_correct)


# Split rules by the judge results: a rule is "known" only when every judge
# answer for it was "Yes". Rules with a single direct example are skipped.
# A rule that is absent from the judge file raises KeyError here.
known, unknown = {}, {}
for key in direct_dict:
    if len(direct_dict[key]) == 1:
        continue
    # judge_dict holds only booleans, so checking for False is sufficient.
    if False in judge_dict[key]:
        unknown[key] = direct_dict[key]
    else:
        known[key] = direct_dict[key]


print(f"[Known Size]: {len(known)}")
print(f"[Unknown Size]: {len(unknown)}")

def abstract_acc(data):
    """Compute rule-level hard and soft accuracy.

    Args:
        data: Mapping from a rule key to a list of per-example correctness
            flags. Each list is expected to be non-empty; an empty list
            raises ``ZeroDivisionError``.

    Returns:
        A ``(hard, soft)`` tuple. ``hard`` is the fraction of rules whose
        flags are all ``True``; ``soft`` is the mean over rules of each
        rule's fraction of ``True`` flags. Both are ``0.0`` when ``data``
        is empty.
    """
    if not data:
        # Nothing to average over; avoid dividing by zero.
        return 0.0, 0.0

    count_hard, count_soft = 0, 0
    for outcomes in data.values():
        correct = outcomes.count(True)
        count_soft += correct / len(outcomes)
        if correct == len(outcomes):
            count_hard += 1
    return count_hard / len(data), count_soft / len(data)

# Score each partition, then the union of both partitions.
known_outs, unknown_outs = abstract_acc(known), abstract_acc(unknown)
outs = abstract_acc({**known, **unknown})

# Each result is a (hard, soft) pair: index 0 is hard, index 1 is soft.
print(f"[Known Hard Acc]: {known_outs[0]}")
print(f"[Known Soft Acc]: {known_outs[1]}")
print(f"[Unknown Hard Acc]: {unknown_outs[0]}")
print(f"[Unknown Soft Acc]: {unknown_outs[1]}")

# The combined (hard, soft) pair is printed as a raw tuple.
print(outs)
















